/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef EXAMPLES_MATRIX_MATMUL_PRELOAD_CUSTOM_TILING_H
#define EXAMPLES_MATRIX_MATMUL_PRELOAD_CUSTOM_TILING_H
#include "kernel_operator.h"
#define ASCENDC_CUBE_ONLY
#include "lib/matmul_intf.h"

namespace CustomMatmulPreload {
template <typename aType, typename bType, typename cType, typename biasType, int32_t preloadMode>
class MatmulPreloadKernel {
private:
    // Operand descriptors for the Matmul object: every operand is read from / written to GM in ND format.
    using AMatType = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, aType>;
    using BMatType = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, bType>;
    using CMatType = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, cType>;
    using BiasMatType = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, biasType>;

public:
    __aicore__ inline MatmulPreloadKernel() {}

    // Binds GM buffers and narrows them (and the stored tiling) to this core's tile.
    __aicore__ inline void Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias, GM_ADDR c, GM_ADDR workspace,
        const TCubeTiling& tiling);

    // Registers the matmul object on the given pipe and computes C = A * B (+ bias).
    __aicore__ inline void Process(AscendC::TPipe* pipe);

    // MDL configuration; preloadMode is forwarded as the third GetMDLConfig argument
    // (presumably the MTE2 preload selector: 0 = off, otherwise M- or N-direction preload -- confirm
    // against the Ascend C GetMDLConfig documentation).
    static constexpr MatmulConfig MM_CFG = GetMDLConfig(false, false, preloadMode);
    AscendC::Matmul<AMatType, BMatType, CMatType, BiasMatType, MM_CFG> matmulObj;

private:
    // Maps blockIdx to GM element offsets and clips the single-core tile sizes for tail tiles.
    __aicore__ inline void CalcOffset(int32_t blockIdx, TCubeTiling& param, int32_t& offsetA, int32_t& offsetB,
        int32_t& offsetC, int32_t& offsetBias);

    AscendC::GlobalTensor<aType> aGlobal;
    AscendC::GlobalTensor<bType> bGlobal;
    AscendC::GlobalTensor<cType> cGlobal;
    AscendC::GlobalTensor<biasType> biasGlobal;
    TCubeTiling tiling;
};

template <typename aType, typename bType, typename cType, typename biasType, int32_t preloadMode>
__aicore__ inline void MatmulPreloadKernel<aType, bType, cType, biasType, preloadMode>::Init(GM_ADDR a,
        GM_ADDR b, GM_ADDR bias, GM_ADDR c, GM_ADDR workspace, const TCubeTiling& tiling)
{
    // Binds the A/B/C/bias GM buffers and shifts each view to the tile owned by this core.
    //
    // a, b, bias, c : GM base addresses of the full matrices (A is M x Ka, B is Kb x N, C is M x N, bias is N).
    // workspace     : not used here; the system workspace is fetched via GetSysWorkSpacePtr() in Process.
    // tiling        : global tiling. A copy is stored in this->tiling and then narrowed by CalcOffset
    //                 (singleCoreM/N/K are clipped for tail tiles).
    (void)workspace;
    this->tiling = tiling;

    // Widen to 64 bits before multiplying: the tiling dimensions are 32-bit (assumed int32_t -- confirm),
    // so e.g. M * Ka could overflow for large matrices.
    aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ aType*>(a), static_cast<uint64_t>(tiling.M) * tiling.Ka);
    bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ bType*>(b), static_cast<uint64_t>(tiling.Kb) * tiling.N);
    cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ cType*>(c), static_cast<uint64_t>(tiling.M) * tiling.N);
    biasGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ biasType*>(bias), static_cast<uint64_t>(tiling.N));

    int32_t offsetA = 0;
    int32_t offsetB = 0;
    int32_t offsetC = 0;
    int32_t offsetBias = 0;
    // Uses (and mutates) the member copy so that Process sees the tail-adjusted tile sizes.
    CalcOffset(AscendC::GetBlockIdx(), this->tiling, offsetA, offsetB, offsetC, offsetBias);

    aGlobal = aGlobal[offsetA];
    bGlobal = bGlobal[offsetB];
    cGlobal = cGlobal[offsetC];
    biasGlobal = biasGlobal[offsetBias];
    // The former trailing "if (GetSysWorkSpacePtr() == nullptr) return;" was removed: as the last
    // statement of a void function it had no effect.
}

template <typename aType, typename bType, typename cType, typename biasType, int32_t preloadMode>
__aicore__ inline void MatmulPreloadKernel<aType, bType, cType, biasType, preloadMode>::Process(AscendC::TPipe* pipe)
{
    // Register the matmul object with the pipe and the system workspace, handing it the per-core
    // tiling that Init has already narrowed for this block.
    REGIST_MATMUL_OBJ(pipe, GetSysWorkSpacePtr(), matmulObj, &(this->tiling));

    // Neither operand is transposed: A is consumed as (M x K), B as (K x N).
    constexpr bool transposeA = false;
    constexpr bool transposeB = false;
    this->matmulObj.SetTensorA(this->aGlobal, transposeA);
    this->matmulObj.SetTensorB(this->bGlobal, transposeB);

    // Bias is attached only when the tiling says it is present.
    if (this->tiling.isBias) {
        this->matmulObj.SetBias(this->biasGlobal);
    }

    // Compute the whole single-core tile straight into GM, then release the matmul object.
    this->matmulObj.IterateAll(this->cGlobal);
    this->matmulObj.End();
}


template <typename aType, typename bType, typename cType, typename biasType, int32_t preloadMode>
__aicore__ inline void MatmulPreloadKernel<aType, bType, cType, biasType, preloadMode>::CalcOffset(
    int32_t blockIdx, TCubeTiling& param,
    int32_t& offsetA, int32_t& offsetB, int32_t& offsetC, int32_t& offsetBias)
{
    // Maps a block index to element offsets into A (M x Ka), B (Ka x N), C (M x N) and bias (N),
    // all row-major, and clips param.singleCoreM/N/K to the remaining extent for tail tiles.
    //
    // Core layout: blocks are grouped into K slices of `mnCoreNum` cores each; within a slice the
    // M index varies fastest, then N.
    // NOTE(review): offsets are int32; products such as mCoreIndex * Ka * singleCoreM can overflow
    // for very large matrices -- widening would require changing the declared signature.
    const auto mBlocks = AscendC::Ceil(param.M, param.singleCoreM);
    const auto kBlocks = AscendC::Ceil(param.Ka, param.singleCoreK);

    // Cores per K slice. If usedCoreNum < kBlocks this would be 0 and the % / below would divide
    // by zero, so clamp it to at least one core.
    auto mnCoreNum = param.usedCoreNum / kBlocks;
    if (mnCoreNum < 1) {
        mnCoreNum = 1;
    }

    const auto mnIndex = blockIdx % mnCoreNum;
    const auto mCoreIndex = mnIndex % mBlocks;
    const auto nCoreIndex = mnIndex / mBlocks;
    const auto subKIndex = blockIdx / mnCoreNum;

    offsetA = mCoreIndex * param.Ka * param.singleCoreM + subKIndex * param.singleCoreK;
    offsetB = subKIndex * param.singleCoreK * param.N + nCoreIndex * param.singleCoreN;
    offsetC = mCoreIndex * param.N * param.singleCoreM + nCoreIndex * param.singleCoreN;
    offsetBias = nCoreIndex * param.singleCoreN;

    // Tail handling: the last tile along each axis may be shorter than the nominal tile size.
    int32_t gmUseM = param.M - mCoreIndex * param.singleCoreM;
    param.singleCoreM = gmUseM < param.singleCoreM ? gmUseM : param.singleCoreM;

    int32_t gmUseN = param.N - nCoreIndex * param.singleCoreN;
    param.singleCoreN = gmUseN < param.singleCoreN ? gmUseN : param.singleCoreN;

    int32_t gmUseK = param.Ka - subKIndex * param.singleCoreK;
    param.singleCoreK = gmUseK < param.singleCoreK ? gmUseK : param.singleCoreK;
}
}  // namespace CustomMatmulPreload
#endif // EXAMPLES_MATRIX_MATMUL_PRELOAD_CUSTOM_TILING_H